In [3]:
from tensorflow import keras
from tensorflow.keras import layers
import pathlib
from tensorflow.keras.utils import image_dataset_from_directory

import pandas as pd
import pathlib
from pathlib import Path

import numpy as np
import pandas as pd

# plotting modules
from matplotlib import pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split

import plotly as plotly
plotly.offline.init_notebook_mode()

from tensorflow import keras
from tensorflow.keras import layers

import tensorflow as tf
from keras.utils import to_categorical
from keras.models import load_model
from PIL import Image

import plotly.graph_objects as go
from tensorflow.keras.models import Sequential
from keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import Dense
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report, precision_recall_curve, ConfusionMatrixDisplay

Framing the Problem¶

The goal of this notebook is to implement one of the features of our proposed product that can be solved by machine learning.

This notebook is for Plant Disease Identification.

We will try to use a convolutional neural network to predict the disease on the plant.

Getting the data¶

The dataset that we will be training our models on is gotten from Kaggle - https://www.kaggle.com/datasets/sadmansakibmahi/plant-disease-expert/data

In [4]:
data_folder = pathlib.Path("../../../../Downloads/plant disease/Image Data base/Image Data base")
In [5]:
def create_image_dataframe(folder):
    data = {'ImagePath': [], 'ClassLabel': []}
    for class_folder in folder.iterdir():
        if class_folder.is_dir():
            for img_path in class_folder.iterdir():
                data['ImagePath'].append(img_path)
                data['ClassLabel'].append(class_folder.name)
    return pd.DataFrame(data)
In [6]:
df = create_image_dataframe(data_folder)
In [7]:
df
Out[7]:
ImagePath ClassLabel
0 ..\..\..\..\Downloads\plant disease\Image Data... algal leaf in tea
1 ..\..\..\..\Downloads\plant disease\Image Data... algal leaf in tea
2 ..\..\..\..\Downloads\plant disease\Image Data... algal leaf in tea
3 ..\..\..\..\Downloads\plant disease\Image Data... algal leaf in tea
4 ..\..\..\..\Downloads\plant disease\Image Data... algal leaf in tea
... ... ...
199660 ..\..\..\..\Downloads\plant disease\Image Data... Waterlogging in plant
199661 ..\..\..\..\Downloads\plant disease\Image Data... Waterlogging in plant
199662 ..\..\..\..\Downloads\plant disease\Image Data... Waterlogging in plant
199663 ..\..\..\..\Downloads\plant disease\Image Data... Waterlogging in plant
199664 ..\..\..\..\Downloads\plant disease\Image Data... Waterlogging in plant

199665 rows × 2 columns

As we can see from the above, we have 199,665 images in all the folders combined.

The class labels represent each of the folders we have in our downloaded folder.

In [8]:
len(df['ClassLabel'].unique())
Out[8]:
58

58 classes, that is 58 different kind of 'states' of some set of plants. 'States' and not diseases because of the plants in the dataset are healthy.

In [9]:
df['ClassLabel'].value_counts()
Out[9]:
ClassLabel
Orange Haunglongbing Citrus greening           52872
Grape Esca Black Measles                       13284
Soybean healthy                                12216
Grape Black rot                                11328
Grape Leaf blight Isariopsis Leaf Spot         10332
Tomato Bacterial spot                           6381
Apple Apple scab                                6048
Apple Black rot                                 5964
Tomato Late blight                              5727
Tomato Septoria leaf spot                       5313
Tomato Spider mites Two spotted spider mite     5028
Tomato Target Spot                              4212
Apple healthy                                   3948
Common Rust in corn Leaf                        3918
Tomato healthy                                  3819
Blueberry healthy                               3606
Pepper bell healthy                             3549
Blight in corn Leaf                             3438
Potato Early blight                             3000
Tomato Early blight                             3000
Potato Late blight                              3000
Pepper bell Bacterial spot                      2988
Tomato Leaf Mold                                2856
Corn (maize) healthy                            2790
Strawberry Leaf scorch                          2664
Apple Cedar apple rust                          2640
Cherry (including sour) Powdery mildew          2526
Cherry (including_sour) healthy                 2052
Gray Leaf Spot in corn Leaf                     1722
Tomato Tomato mosaic virus                      1119
Strawberry healthy                              1095
Grape healthy                                   1017
Raspberry healthy                                891
Peach healthy                                    864
red leaf spot in tea                             429
Potato healthy                                   366
algal leaf in tea                                339
brown blight in tea                              339
corn crop                                        312
bird eye spot in tea                             300
anthracnose in tea                               300
cabbage looper                                   234
healthy tea leaf                                 222
Cercospora leaf spot                             189
lemon canker                                     183
potato hollow heart                              180
Garlic                                           147
ginger                                           135
Brown spot in rice leaf                          120
Leaf smut in rice leaf                           120
Bacterial leaf blight in rice leaf               120
potato crop                                      120
Sogatella rice                                    78
onion                                             60
tomato canker                                     57
potassium deficiency in plant                     54
Nitrogen deficiency in plant                      33
Waterlogging in plant                             21
Name: count, dtype: int64

Exploratory Data Analysis¶

Let us put our dataset to view.

In [10]:
import matplotlib.pyplot as plt
from PIL import Image

grouped = df.groupby('ClassLabel')

# Plot the first 5 images for each class
for label, group in grouped:
    print(label)
    print('=====================')
    
    # Create a figure with 5 subplots (1 row, 5 columns)
    fig, axes = plt.subplots(1, 5, figsize=(20, 4))
    
    # Iterate over the first 5 rows of the group
    for i in range(min(5, len(group))):  # This ensures we don't go out of bounds if there are less than 5 images
        img_path = group['ImagePath'].iloc[i]
        img = Image.open(img_path)
        
        # Plot the image in the i-th subplot
        axes[i].imshow(img)
        axes[i].set_title(label)
        axes[i].axis('off')
    
    # Hide any unused subplots if the group has less than 5 images
    if len(group) < 5:
        for j in range(len(group), 5):
            axes[j].axis('off')
    
    plt.show()
Apple Apple scab
=====================
Apple Black rot
=====================
Apple Cedar apple rust
=====================
Apple healthy
=====================
Bacterial leaf blight in rice leaf
=====================
Blight in corn Leaf
=====================
Blueberry healthy
=====================
Brown spot in rice leaf
=====================
Cercospora leaf spot
=====================
Cherry (including sour) Powdery mildew
=====================
Cherry (including_sour) healthy
=====================
Common Rust in corn Leaf
=====================
Corn (maize) healthy
=====================
Garlic
=====================
Grape Black rot
=====================
Grape Esca Black Measles
=====================
Grape Leaf blight Isariopsis Leaf Spot
=====================
Grape healthy
=====================
Gray Leaf Spot in corn Leaf
=====================
Leaf smut in rice leaf
=====================
Nitrogen deficiency in plant
=====================
Orange Haunglongbing Citrus greening
=====================
Peach healthy
=====================
Pepper bell Bacterial spot
=====================
Pepper bell healthy
=====================
Potato Early blight
=====================
Potato Late blight
=====================
Potato healthy
=====================
Raspberry healthy
=====================
Sogatella rice
=====================
Soybean healthy
=====================
Strawberry Leaf scorch
=====================
Strawberry healthy
=====================
Tomato Bacterial spot
=====================
Tomato Early blight
=====================
Tomato Late blight
=====================
Tomato Leaf Mold
=====================
Tomato Septoria leaf spot
=====================
Tomato Spider mites Two spotted spider mite
=====================
Tomato Target Spot
=====================
Tomato Tomato mosaic virus
=====================
Tomato healthy
=====================
Waterlogging in plant
=====================
algal leaf in tea
=====================
anthracnose in tea
=====================
bird eye spot in tea
=====================
brown blight in tea
=====================
cabbage looper
=====================
corn crop
=====================
ginger
=====================
healthy tea leaf
=====================
lemon canker
=====================
onion
=====================
potassium deficiency in plant
=====================
potato crop
=====================
potato hollow heart
=====================
red leaf spot in tea
=====================
tomato canker
=====================

Modeling¶

Before we train our model, we need to prepare our images to the format that the model expects.

In [145]:
df['ImagePath'] = df['ImagePath'].astype(str)
grouped = df.groupby('ClassLabel')

train_df = pd.DataFrame()
val_df = pd.DataFrame()
test_df = pd.DataFrame()

for _, group in grouped:
    # Split into train and temporary test
    train_tmp, test_tmp = train_test_split(group, test_size=0.3, random_state=42)  # 70% train, 30% temp test
    
    # Split the temporary test into actual validation and test
    val_tmp, test_final = train_test_split(test_tmp, test_size=0.5, random_state=42)  # Split 30% into 15% val, 15% test

    # Append to the respective DataFrames
    train_df = pd.concat([train_df, train_tmp])
    val_df = pd.concat([val_df, val_tmp])
    test_df = pd.concat([test_df, test_final])

# Now, you have train_df, val_df, and test_df
print(f"Training Set: {train_df.shape[0]} samples")
print(f"Validation Set: {val_df.shape[0]} samples")
print(f"Test Set: {test_df.shape[0]} samples")
Training Set: 139743 samples
Validation Set: 29950 samples
Test Set: 29972 samples

Here, we are applying some data augmentation to the dataset to help improve the performance of the model.

We apply some rotation to the images, move the width and heirght around a bit, also flipping them and inverting them.

Another very important step you can see below is that we have rescaled the rgb values of the images from 0 - 256 to values just between 0 and 1. This is an important step because it makes sure that the model doesn't assign more weights to the pixels with very high values.

In [146]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(
    rescale=1./255,
    featurewise_center=False,  # set input mean to 0 over the dataset
    samplewise_center=False,  # set each sample mean to 0
    featurewise_std_normalization=False,  # divide inputs by std of the dataset
    samplewise_std_normalization=False,  # divide each input by its std
    zca_whitening=False,  # apply ZCA whitening
    rotation_range=20,  # randomly rotate images in the range (degrees, 0 to 180)
    width_shift_range=0.2,  # randomly shift images horizontally (fraction of total width)
    height_shift_range=0.2,  # randomly shift images vertically (fraction of total height)
    horizontal_flip=True,  # randomly flip images
    vertical_flip=True
)

val_datagen = ImageDataGenerator(rescale=1./255)
In [148]:
train_generator = train_datagen.flow_from_dataframe(
    dataframe=train_df,
    x_col="ImagePath",
    y_col="ClassLabel",
    batch_size=32,
    seed=42,
    shuffle=True,
    class_mode="categorical",
    target_size=(150,150))

valid_generator = val_datagen.flow_from_dataframe(
    dataframe=val_df,
    x_col="ImagePath",
    y_col="ClassLabel",
    batch_size=32,
    seed=42,
    shuffle=True,
    class_mode="categorical",
    target_size=(150,150))

test_generator = val_datagen.flow_from_dataframe(
    dataframe=test_df,
    x_col="ImagePath",
    y_col="ClassLabel",
    batch_size=32,
    seed=42,
    shuffle=False,
    class_mode="categorical",
    target_size=(150,150))
Found 139743 validated image filenames belonging to 58 classes.
Found 29950 validated image filenames belonging to 58 classes.
Found 29972 validated image filenames belonging to 58 classes.

Convolutional Neural Network Model¶

With tensorflow and keras, we will be using a fully connected neural network for our model.

In [149]:
from tensorflow import keras
from tensorflow.keras import layers

inputs = keras.Input(shape=(150, 150, 3))
x = layers.Conv2D(filters=32, kernel_size=3, activation="relu")(inputs)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=64, kernel_size=3, activation="relu")(x)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=128, kernel_size=3, activation="relu")(x)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=256, kernel_size=3, activation="relu")(x)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=256, kernel_size=3, activation="relu")(x)
x = layers.Flatten()(x)
outputs = layers.Dense(58, activation="softmax")(x)  # Adjusted for 58 classes with softmax activation
model = keras.Model(inputs=inputs, outputs=outputs)

Model Explainability¶

Model Architecture and Training Overview¶

We've constructed a neural network model using TensorFlow and Keras, designed specifically for image classification. This model processes images sized 150x150 pixels with 3 color channels (RGB).

Model Layers:

The model starts with a series of convolutional layers, where each layer is designed to detect specific features in the image like edges, textures, or more complex patterns as we go deeper.

Each convolutional layer is followed by a max-pooling layer, which reduces the spatial dimensions of the image representations, helping to make the detection of features invariant to scale and orientation changes.

After multiple layers of convolutions and pooling, the high-level understanding of the images is flattened into a vector that serves as input to a fully connected layer.

Output Layer:

The final layer is a dense layer with 58 units, corresponding to the number of categories we want to classify. It uses the softmax activation function to output probabilities for each class, indicating the likelihood of the image belonging to each class.

Model Compilation:

The model uses the Adam optimizer for adjusting weights, which is effective and efficient for this kind of problem. It minimizes a function called categorical crossentropy, a common choice for classification tasks, which measures the difference between the predicted probabilities and the actual distribution of the labels. We track the accuracy during training as a straightforward metric to understand how well the model performs.

Training Process Enhancements¶

To optimize training and avoid common pitfalls:

ModelCheckpoint: Saves the best model as we train, ensuring that we always have the version of the model that performed best on the validation set, in case later iterations perform worse.

EarlyStopping: Monitors the model's performance on a validation set and stops training if the model's performance doesn't improve for 10 consecutive epochs. This prevents overfitting and unnecessary computation by stopping when the model isn't learning anymore.

ReduceLROnPlateau: Reduces the learning rate if the validation loss stops improving. Smaller steps in weight updates can lead to better fine-tuning and better overall model performance.

Execution¶

The model is trained using the specified training and validation datasets for 15 epochs, but training can stop early if no improvement is seen as monitored by our callbacks. This setup ensures that the training is both efficient and effective, adapting to the data as needed.

In [150]:
model.summary()
Model: "model_18"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_33 (InputLayer)       [(None, 150, 150, 3)]     0         
                                                                 
 conv2d_288 (Conv2D)         (None, 148, 148, 32)      896       
                                                                 
 max_pooling2d_16 (MaxPoolin  (None, 74, 74, 32)       0         
 g2D)                                                            
                                                                 
 conv2d_289 (Conv2D)         (None, 72, 72, 64)        18496     
                                                                 
 max_pooling2d_17 (MaxPoolin  (None, 36, 36, 64)       0         
 g2D)                                                            
                                                                 
 conv2d_290 (Conv2D)         (None, 34, 34, 128)       73856     
                                                                 
 max_pooling2d_18 (MaxPoolin  (None, 17, 17, 128)      0         
 g2D)                                                            
                                                                 
 conv2d_291 (Conv2D)         (None, 15, 15, 256)       295168    
                                                                 
 max_pooling2d_19 (MaxPoolin  (None, 7, 7, 256)        0         
 g2D)                                                            
                                                                 
 conv2d_292 (Conv2D)         (None, 5, 5, 256)         590080    
                                                                 
 flatten_18 (Flatten)        (None, 6400)              0         
                                                                 
 dense_42 (Dense)            (None, 58)                371258    
                                                                 
=================================================================
Total params: 1,349,754
Trainable params: 1,349,754
Non-trainable params: 0
_________________________________________________________________
In [151]:
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)
In [152]:
# Define callbacks
from tensorflow.keras import optimizers, callbacks

my_callbacks = [
    callbacks.ModelCheckpoint(filepath='./models/vanilla.h5', save_best_only=True),
    callbacks.EarlyStopping(monitor='val_loss', patience=10, verbose=1),
    callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, verbose=1, min_lr=1e-5)
]
In [153]:
history = model.fit(
    train_generator,
    validation_data=valid_generator,
    epochs=15,
    callbacks=my_callbacks
)
Epoch 1/15
4367/4367 [==============================] - 545s 125ms/step - loss: 1.0177 - accuracy: 0.7088 - val_loss: 0.7434 - val_accuracy: 0.7914 - lr: 0.0010
Epoch 2/15
4367/4367 [==============================] - 568s 130ms/step - loss: 0.3204 - accuracy: 0.8986 - val_loss: 0.2596 - val_accuracy: 0.9178 - lr: 0.0010
Epoch 3/15
4367/4367 [==============================] - 544s 125ms/step - loss: 0.2275 - accuracy: 0.9271 - val_loss: 0.2920 - val_accuracy: 0.9094 - lr: 0.0010
Epoch 4/15
4367/4367 [==============================] - 523s 120ms/step - loss: 0.1923 - accuracy: 0.9384 - val_loss: 0.1396 - val_accuracy: 0.9573 - lr: 0.0010
Epoch 5/15
4367/4367 [==============================] - 759s 174ms/step - loss: 0.1724 - accuracy: 0.9457 - val_loss: 0.1897 - val_accuracy: 0.9451 - lr: 0.0010
Epoch 6/15
4367/4367 [==============================] - 841s 193ms/step - loss: 0.1577 - accuracy: 0.9502 - val_loss: 0.1475 - val_accuracy: 0.9559 - lr: 0.0010
Epoch 7/15
4367/4367 [==============================] - 839s 192ms/step - loss: 0.1463 - accuracy: 0.9535 - val_loss: 0.1541 - val_accuracy: 0.9541 - lr: 0.0010
Epoch 8/15
4367/4367 [==============================] - 836s 191ms/step - loss: 0.1447 - accuracy: 0.9541 - val_loss: 0.2411 - val_accuracy: 0.9296 - lr: 0.0010
Epoch 9/15
4367/4367 [==============================] - 832s 191ms/step - loss: 0.1378 - accuracy: 0.9567 - val_loss: 0.1329 - val_accuracy: 0.9605 - lr: 0.0010
Epoch 10/15
4367/4367 [==============================] - 833s 191ms/step - loss: 0.1361 - accuracy: 0.9581 - val_loss: 0.1902 - val_accuracy: 0.9438 - lr: 0.0010
Epoch 11/15
4367/4367 [==============================] - 836s 191ms/step - loss: 0.1341 - accuracy: 0.9588 - val_loss: 0.1241 - val_accuracy: 0.9627 - lr: 0.0010
Epoch 12/15
4367/4367 [==============================] - 842s 193ms/step - loss: 0.1275 - accuracy: 0.9610 - val_loss: 0.1314 - val_accuracy: 0.9613 - lr: 0.0010
Epoch 13/15
4367/4367 [==============================] - 832s 191ms/step - loss: 0.1291 - accuracy: 0.9603 - val_loss: 0.1109 - val_accuracy: 0.9685 - lr: 0.0010
Epoch 14/15
4367/4367 [==============================] - 831s 190ms/step - loss: 0.1277 - accuracy: 0.9612 - val_loss: 0.1205 - val_accuracy: 0.9649 - lr: 0.0010
Epoch 15/15
4367/4367 [==============================] - 829s 190ms/step - loss: 0.1294 - accuracy: 0.9605 - val_loss: 0.1047 - val_accuracy: 0.9690 - lr: 0.0010
In [154]:
history_df = pd.DataFrame(history.history)
history_df.insert(0, 'epoch', range(1, len(history_df) + 1))
history_df
Out[154]:
epoch loss accuracy val_loss val_accuracy lr
0 1 1.017690 0.708801 0.743360 0.791386 0.001
1 2 0.320429 0.898642 0.259642 0.917830 0.001
2 3 0.227511 0.927102 0.291983 0.909416 0.001
3 4 0.192348 0.938365 0.139631 0.957329 0.001
4 5 0.172427 0.945736 0.189749 0.945109 0.001
5 6 0.157686 0.950223 0.147521 0.955927 0.001
6 7 0.146324 0.953543 0.154100 0.954124 0.001
7 8 0.144717 0.954059 0.241131 0.929583 0.001
8 9 0.137810 0.956699 0.132919 0.960467 0.001
9 10 0.136100 0.958109 0.190158 0.943840 0.001
10 11 0.134150 0.958760 0.124125 0.962704 0.001
11 12 0.127487 0.961007 0.131421 0.961302 0.001
12 13 0.129079 0.960306 0.110887 0.968548 0.001
13 14 0.127737 0.961172 0.120532 0.964875 0.001
14 15 0.129431 0.960542 0.104676 0.969048 0.001
In [155]:
# Create a DataFrame from the history object
history_df = pd.DataFrame(history.history)

# Plot the training and validation loss
plt.figure(figsize=(9, 5))
values = history_df['accuracy']
epochs = range(1, len(values) + 1)
plt.plot(epochs, history_df['loss'], 'bo', label='Training loss')
plt.plot(epochs, history_df['val_loss'], 'ro', label='Validation loss')

plt.xlabel('Epochs')
plt.xticks(epochs)
plt.ylabel('Loss')
plt.legend()
plt.title('Training and validation loss')
plt.show()

# Plot the training and validation accuracy
plt.figure(figsize=(9, 5))
plt.plot(epochs, history_df['accuracy'], 'bo', label='Training accuracy')
plt.plot(epochs, history_df['val_accuracy'], 'ro', label='Validation accuracy')

plt.xlabel('Epochs')
plt.xticks(epochs)
plt.ylabel('Accuracy')
plt.legend()
plt.title('Training and validation accuracy')
plt.show()

Interpreting the Model Performance¶

After training the model over 15 epochs, we were able to get an accuracy score of 96.9% on our validation dataset.

Evaluating the Model on the Test Data¶

In [156]:
# evaluate the model
model.evaluate(test_generator)
937/937 [==============================] - 57s 61ms/step - loss: 0.1024 - accuracy: 0.9693
Out[156]:
[0.102357417345047, 0.969338059425354]
In [157]:
# predict the model
y_pred = model.predict(test_generator)

# get the class with the highest probability
y_pred_labels = np.argmax(y_pred, axis=1)

# get the true class
y_true = np.array(test_generator.classes)


# get the class labels
class_labels = list(test_generator.class_indices.keys())

display(classification_report(y_true, y_pred_labels, target_names=class_labels, zero_division=0))
# get the accuracy
accuracy = accuracy_score(y_true, y_pred_labels)
print(f'Accuracy: {accuracy}')
937/937 [==============================] - 50s 54ms/step
                                             precision    recall  f1-score   support

                           Apple Apple scab       1.00      0.98      0.99       908
                            Apple Black rot       0.98      1.00      0.99       895
                     Apple Cedar apple rust       0.98      0.98      0.98       396
                              Apple healthy       0.97      0.96      0.97       593
         Bacterial leaf blight in rice leaf       1.00      0.89      0.94        18
                        Blight in corn Leaf       0.88      0.84      0.86       516
                          Blueberry healthy       0.98      0.99      0.99       541
                    Brown spot in rice leaf       0.85      0.94      0.89        18
                       Cercospora leaf spot       0.77      0.69      0.73        29
     Cherry (including sour) Powdery mildew       0.99      0.98      0.99       379
            Cherry (including_sour) healthy       0.96      1.00      0.98       308
                   Common Rust in corn Leaf       0.97      0.95      0.96       588
                       Corn (maize) healthy       1.00      0.99      0.99       419
                                     Garlic       0.62      0.87      0.73        23
                            Grape Black rot       0.98      0.98      0.98      1700
                   Grape Esca Black Measles       0.97      1.00      0.99      1993
     Grape Leaf blight Isariopsis Leaf Spot       1.00      0.95      0.97      1550
                              Grape healthy       0.99      0.98      0.98       153
                Gray Leaf Spot in corn Leaf       0.72      0.88      0.79       259
                     Leaf smut in rice leaf       1.00      0.83      0.91        18
               Nitrogen deficiency in plant       0.50      0.20      0.29         5
       Orange Haunglongbing Citrus greening       1.00      1.00      1.00      7931
                              Peach healthy       0.98      0.96      0.97       130
                 Pepper bell Bacterial spot       0.95      0.98      0.96       449
                        Pepper bell healthy       0.97      0.95      0.96       533
                        Potato Early blight       0.93      1.00      0.96       450
                         Potato Late blight       0.96      0.92      0.94       450
                             Potato healthy       0.78      0.91      0.84        55
                          Raspberry healthy       0.89      0.98      0.93       134
                             Sogatella rice       0.45      0.42      0.43        12
                            Soybean healthy       0.98      0.99      0.99      1833
                     Strawberry Leaf scorch       0.89      0.98      0.93       400
                         Strawberry healthy       1.00      0.93      0.97       165
                      Tomato Bacterial spot       1.00      0.96      0.98       958
                        Tomato Early blight       0.96      0.86      0.91       450
                         Tomato Late blight       0.92      0.95      0.94       860
                           Tomato Leaf Mold       0.98      0.95      0.96       429
                  Tomato Septoria leaf spot       0.95      0.97      0.96       797
Tomato Spider mites Two spotted spider mite       0.97      0.91      0.94       755
                         Tomato Target Spot       0.88      0.94      0.91       632
                 Tomato Tomato mosaic virus       1.00      0.96      0.98       168
                             Tomato healthy       0.99      0.98      0.99       573
                      Waterlogging in plant       1.00      0.25      0.40         4
                          algal leaf in tea       0.89      0.92      0.90        51
                         anthracnose in tea       0.88      0.51      0.65        45
                       bird eye spot in tea       0.66      0.91      0.77        45
                        brown blight in tea       0.94      0.92      0.93        51
                             cabbage looper       0.63      0.53      0.58        36
                                  corn crop       0.72      0.87      0.79        47
                                     ginger       0.50      0.76      0.60        21
                           healthy tea leaf       0.94      1.00      0.97        34
                               lemon canker       0.64      0.50      0.56        28
                                      onion       0.83      0.56      0.67         9
              potassium deficiency in plant       0.67      0.67      0.67         9
                                potato crop       0.62      0.44      0.52        18
                        potato hollow heart       0.63      0.44      0.52        27
                       red leaf spot in tea       0.95      0.97      0.96        65
                              tomato canker       0.17      0.33      0.22         9

                                   accuracy                           0.97     29972
                                  macro avg       0.87      0.84      0.85     29972
                               weighted avg       0.97      0.97      0.97     29972

Accuracy: 0.9693380488455892

Above, we can see how the model performed. We have an accuracy of 96.9% on the test data. And then we can also see how it performed on each class.

As expected, there is a better performance on the classes with more data, and the dataset is somewhat imbalanced.

But an impressive performance none the less.

In [164]:
model.save('./models/plant-disease-vanilla.h5')

Future Work¶

Going forward, and in an attempt to improve this machine learning process, there are a few things that we can consider doing.

  1. Improving the dataset. More images to train on means that the model can learn better and perform even better.

  2. Balancing the dataset. One way that will definitely improve our model performance would be to get more images for the classes where we currently have just few of them.

  3. Using even more advanced models for transfer learning. In addition to what has been mentioned, we could also use popular models like VGG16, VGG19, ResNet50, Inception and the likes. But this calls for more GPU processing power.

RestNet

In [158]:
from tensorflow.keras.applications import ResNet50
resnet_base = ResNet50(weights='imagenet', include_top=False, input_shape=(150,150,3))
In [159]:
resnet_base.trainable = False
In [160]:
inputs = keras.Input(shape=(150, 150, 3))
x = keras.applications.resnet50.preprocess_input(inputs)
x = resnet_base(x)
x = layers.Flatten()(x)
x = layers.Dense(256)(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(58, activation="softmax")(x)
model_resnet = keras.Model(inputs, outputs)
In [161]:
for layer in resnet_base.layers[:143]:
   layer.trainable = False
for layer in resnet_base.layers[143:]:
   layer.trainable = True
In [162]:
model_resnet.compile(
    optimizer=keras.optimizers.RMSprop(lr=2e-5),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath="./models/finetune-resnet.keras",
        save_best_only=True,
        monitor="val_loss")
]
c:\Users\Owner\.conda\envs\tf4gpu\lib\site-packages\keras\optimizers\optimizer_v2\rmsprop.py:140: UserWarning:

The `lr` argument is deprecated, use `learning_rate` instead.

In [163]:
history_resnet = model_resnet.fit(
    train_generator,
    validation_data=valid_generator,
    epochs=1,
    callbacks=my_callbacks
)
4367/4367 [==============================] - 839s 191ms/step - loss: 2.5297 - accuracy: 0.4090 - val_loss: 1.8988 - val_accuracy: 0.5369 - lr: 2.0000e-05